# SPDX-License-Identifier: Apache-2.0
"""
cache-dit integration module for SGLang DiT pipelines.

This module provides helper functions to enable cache-dit acceleration
on transformer modules in SGLang's modular pipeline architecture.
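
Typical usage (illustrative sketch; see the per-function docstrings below):

    config = CacheDitConfig(enabled=True, num_inference_steps=28)
    transformer = enable_cache_on_transformer(transformer, config)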
"""

from dataclasses import dataclass
from typing import List, Optional

import cache_dit
import torch
from cache_dit import (
    BlockAdapter,
    DBCacheConfig,
    ForwardPattern,
    ParamsModifier,
    TaylorSeerCalibratorConfig,
    steps_mask,
)
from cache_dit.caching.block_adapters import BlockAdapterRegister

from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)


def get_scm_mask(
    preset: str,
    num_inference_steps: int,
    compute_bins: Optional[List[int]] = None,
    cache_bins: Optional[List[int]] = None,
) -> Optional[List[int]]:
    """
    Get SCM mask using cache-dit's steps_mask().

    This is a thin wrapper that delegates to cache-dit's built-in
    steps_mask() function which handles all presets and scaling logic.

    Args:
        preset: Preset name ("none", "slow", "medium", "fast", "ultra").
        num_inference_steps: Total number of inference steps (the mask length).
        compute_bins: Custom compute bins (override the preset).
        cache_bins: Custom cache bins (override the preset).

    Returns:
        SCM mask list (1=compute, 0=cache), or None if disabled.
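
    Example:
        A minimal sketch (the exact 0/1 pattern depends on cache-dit's
        preset definitions):

            mask = get_scm_mask(preset="fast", num_inference_steps=20)
            # -> a list of 20 ints such as [1, 1, 0, 1, ...]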
    """
    if preset == "none" and not (compute_bins and cache_bins):
        return None

    # Delegate to cache-dit's steps_mask(). When preset is "none" we can only
    # reach this point with custom bins supplied (see the guard above), and
    # custom bins override the preset policy, so "medium" is a placeholder.
    mask = steps_mask(
        compute_bins=compute_bins,
        cache_bins=cache_bins,
        total_steps=num_inference_steps,
        mask_policy=preset if preset != "none" else "medium",
    )

    compute_count = sum(mask)
    cache_count = len(mask) - compute_count
    logger.info(
        "SCM: generated mask with %d compute steps, %d cache steps (preset=%s)",
        compute_count,
        cache_count,
        preset,
    )

    return mask


@dataclass
class CacheDitConfig:
    """Configuration for cache-dit integration.

    Attributes:
        enabled: Whether to enable cache-dit acceleration.
        Fn_compute_blocks: Number of first blocks to always compute (DBCache F).
        Bn_compute_blocks: Number of last blocks to always compute (DBCache B).
        max_warmup_steps: Number of warmup steps before caching starts (DBCache W).
        residual_diff_threshold: Threshold for residual difference (DBCache R).
        max_continuous_cached_steps: Maximum consecutive cached steps (DBCache MC).
        enable_taylorseer: Whether to enable TaylorSeer calibrator.
        taylorseer_order: Order of Taylor expansion (1 or 2).
        num_inference_steps: Total number of inference steps (required whenever caching is enabled).
        steps_computation_mask: Binary mask for step-level caching (1=compute, 0=cache).
            Generated by get_scm_mask() (wrapper around cache_dit.steps_mask()).
        steps_computation_policy: Caching policy for SCM ("dynamic" or "static").
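
    Example:
        A typical configuration for an 8-step distilled model (illustrative
        values; num_inference_steps must match the pipeline's actual step
        count):

            config = CacheDitConfig(
                enabled=True,
                num_inference_steps=8,
                steps_computation_mask=get_scm_mask("fast", 8),
            )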
    """

    enabled: bool = False
    Fn_compute_blocks: int = 1
    Bn_compute_blocks: int = 0
    # Default to 4 warmup steps instead of cache-dit's default of 8, so that
    # DBCache also works for few-step distilled models, e.g., Z-Image with 8 steps.
    max_warmup_steps: int = 4
    # Use a relatively high residual diff threshold (0.24) by default to allow
    # more aggressive caching, since the max-continuous-cached-steps limit is
    # already applied; without that limit, a lower threshold such as 0.12
    # would be appropriate.
    residual_diff_threshold: float = 0.24
    max_continuous_cached_steps: int = 3
    # TaylorSeer is not suitable for few-step distilled models, so it is
    # disabled by default. References:
    # - From Reusing to Forecasting: Accelerating Diffusion Models with
    #   TaylorSeers, https://arxiv.org/pdf/2503.06923
    # - FoCa: Forecast then Calibrate: Feature Caching as ODE for Efficient
    #   Diffusion Transformers, https://arxiv.org/pdf/2508.16211
    enable_taylorseer: bool = False
    taylorseer_order: int = 1
    num_inference_steps: Optional[int] = None
    # SCM fields (generated by _maybe_enable_cache_dit from env configuration)
    steps_computation_mask: Optional[List[int]] = None
    steps_computation_policy: str = "dynamic"


def enable_cache_on_transformer(
    transformer: torch.nn.Module,
    config: CacheDitConfig,
    model_name: str = "transformer",
) -> torch.nn.Module:
    """Enable cache-dit on a transformer module, by wrapping the module with cache-dit

    This function enables cache-dit acceleration using the BlockAdapterRegister
    for pre-registered models

    Args:
        model_name: Name of the model for logging purposes.
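
    Returns:
        The same transformer module, with cache-dit caching enabled on it.

    Example:
        A minimal sketch, assuming a pipeline object (here called pipe)
        whose transformer class is pre-registered in cache-dit:

            config = CacheDitConfig(enabled=True, num_inference_steps=28)
            pipe.transformer = enable_cache_on_transformer(
                pipe.transformer, config, model_name="flux"
            )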

    """
    if not config.enabled:
        return transformer

    if config.num_inference_steps is None:
        raise ValueError(
            "num_inference_steps is required for transformer-only mode. "
            "Please provide it in CacheDitConfig."
        )

    # Check if the transformer is pre-registered in cache-dit
    if not BlockAdapterRegister.is_supported(transformer):
        transformer_cls_name = transformer.__class__.__name__
        raise ValueError(
            f"{transformer_cls_name} is not officially supported by cache-dit. "
            "Supported cache-dit DiT families include Flux, QwenImage, HunyuanDiT, "
            "HunyuanVideo, Wan, CogVideoX, Mochi, and others. "
            "Please ensure your transformer belongs to one of these families or "
            "define a custom BlockAdapter."
        )

    # Build cache config (including SCM fields if provided)
    cache_config = DBCacheConfig(
        num_inference_steps=config.num_inference_steps,
        Fn_compute_blocks=config.Fn_compute_blocks,
        Bn_compute_blocks=config.Bn_compute_blocks,
        max_warmup_steps=config.max_warmup_steps,
        residual_diff_threshold=config.residual_diff_threshold,
        max_continuous_cached_steps=config.max_continuous_cached_steps,
        # SCM fields
        steps_computation_mask=config.steps_computation_mask,
        steps_computation_policy=config.steps_computation_policy,
    )

    # Build calibrator config if TaylorSeer is enabled
    calibrator_config = None
    if config.enable_taylorseer:
        calibrator_config = TaylorSeerCalibratorConfig(
            taylorseer_order=config.taylorseer_order,
        )

    # Enable cache-dit on the transformer
    logger.info(
        "Enabling cache-dit on %s with config: Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, "
        "TaylorSeer=%s (order=%d), steps=%d",
        model_name,
        config.Fn_compute_blocks,
        config.Bn_compute_blocks,
        config.max_warmup_steps,
        config.residual_diff_threshold,
        config.max_continuous_cached_steps,
        config.enable_taylorseer,
        config.taylorseer_order,
        config.num_inference_steps,
    )

    # Log SCM configuration if enabled
    if config.steps_computation_mask:
        compute_steps = sum(config.steps_computation_mask)
        cache_steps = len(config.steps_computation_mask) - compute_steps
        logger.info(
            "SCM enabled: %d compute steps, %d cache steps, policy=%s",
            compute_steps,
            cache_steps,
            config.steps_computation_policy,
        )

    cache_dit.enable_cache(
        transformer,
        cache_config=cache_config,
        calibrator_config=calibrator_config,
    )

    return transformer


def enable_cache_on_dual_transformer(
    transformer: torch.nn.Module,
    transformer_2: torch.nn.Module,
    primary_config: CacheDitConfig,
    secondary_config: CacheDitConfig,
    model_name: str = "wan2.2",
) -> tuple[torch.nn.Module, torch.nn.Module]:
    """Enable cache-dit on dual transformers using BlockAdapter.

    For models with two transformers (high-noise expert and low-noise expert),
    cache-dit requires enabling cache on both simultaneously via BlockAdapter.
    This cannot be done by calling enable_cache separately on each transformer.

    Args:
        transformer: Primary transformer (e.g., the high-noise expert).
        transformer_2: Secondary transformer (e.g., the low-noise expert).
        primary_config: CacheDitConfig for the primary transformer.
        secondary_config: CacheDitConfig for the secondary transformer.
        model_name: Name of the model; currently only "wan2.2" is supported.
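
    Returns:
        The (transformer, transformer_2) pair, with cache-dit enabled on both.

    Example:
        A minimal sketch for a Wan2.2-style pipeline (pipe is a hypothetical
        pipeline object holding both experts):

            cfg = CacheDitConfig(enabled=True, num_inference_steps=40)
            pipe.transformer, pipe.transformer_2 = enable_cache_on_dual_transformer(
                pipe.transformer, pipe.transformer_2, cfg, cfg, model_name="wan2.2"
            )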
    """
    _supported_dual_transformer_models = [
        "wan2.2",  # Currently, only Wan2.2 will run into dual-transformer case
    ]
    if model_name not in _supported_dual_transformer_models:
        raise ValueError(
            f"Dual-transformer cache-dit is only supported for "
            f"{_supported_dual_transformer_models}, got {model_name}."
        )

    if not primary_config.enabled:
        return transformer, transformer_2

    if (
        primary_config.num_inference_steps is None
        or secondary_config.num_inference_steps is None
    ):
        raise ValueError(
            "num_inference_steps is required for dual-transformer mode. "
            "Please provide it in both CacheDitConfig instances."
        )

    # Build DBCacheConfig for primary transformer
    primary_cache_config = DBCacheConfig(
        num_inference_steps=primary_config.num_inference_steps,
        Fn_compute_blocks=primary_config.Fn_compute_blocks,
        Bn_compute_blocks=primary_config.Bn_compute_blocks,
        max_warmup_steps=primary_config.max_warmup_steps,
        residual_diff_threshold=primary_config.residual_diff_threshold,
        max_continuous_cached_steps=primary_config.max_continuous_cached_steps,
        steps_computation_mask=primary_config.steps_computation_mask,
        steps_computation_policy=primary_config.steps_computation_policy,
    )

    # Build DBCacheConfig for secondary transformer
    secondary_cache_config = DBCacheConfig(
        num_inference_steps=secondary_config.num_inference_steps,
        Fn_compute_blocks=secondary_config.Fn_compute_blocks,
        Bn_compute_blocks=secondary_config.Bn_compute_blocks,
        max_warmup_steps=secondary_config.max_warmup_steps,
        residual_diff_threshold=secondary_config.residual_diff_threshold,
        max_continuous_cached_steps=secondary_config.max_continuous_cached_steps,
        steps_computation_mask=secondary_config.steps_computation_mask,
        steps_computation_policy=secondary_config.steps_computation_policy,
    )

    # Build calibrator configs if TaylorSeer is enabled
    primary_calibrator = None
    if primary_config.enable_taylorseer:
        primary_calibrator = TaylorSeerCalibratorConfig(
            taylorseer_order=primary_config.taylorseer_order,
        )

    secondary_calibrator = None
    if secondary_config.enable_taylorseer:
        secondary_calibrator = TaylorSeerCalibratorConfig(
            taylorseer_order=secondary_config.taylorseer_order,
        )

    # Build ParamsModifier for each transformer
    primary_modifier = ParamsModifier(
        cache_config=primary_cache_config,
        calibrator_config=primary_calibrator,
    )
    secondary_modifier = ParamsModifier(
        cache_config=secondary_cache_config,
        calibrator_config=secondary_calibrator,
    )

    # Log configuration
    logger.info(
        "Enabling cache-dit on %s dual transformers with BlockAdapter",
        model_name,
    )
    logger.info(
        "  Primary (transformer): Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, TaylorSeer=%s",
        primary_config.Fn_compute_blocks,
        primary_config.Bn_compute_blocks,
        primary_config.max_warmup_steps,
        primary_config.residual_diff_threshold,
        primary_config.max_continuous_cached_steps,
        primary_config.enable_taylorseer,
    )
    logger.info(
        "  Secondary (transformer_2): Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, TaylorSeer=%s",
        secondary_config.Fn_compute_blocks,
        secondary_config.Bn_compute_blocks,
        secondary_config.max_warmup_steps,
        secondary_config.residual_diff_threshold,
        secondary_config.max_continuous_cached_steps,
        secondary_config.enable_taylorseer,
    )

    # Log SCM configuration if enabled
    if primary_config.steps_computation_mask:
        compute_steps = sum(primary_config.steps_computation_mask)
        cache_steps = len(primary_config.steps_computation_mask) - compute_steps
        logger.info(
            "  SCM enabled: %d compute steps, %d cache steps, policy=%s",
            compute_steps,
            cache_steps,
            primary_config.steps_computation_policy,
        )

    # Get blocks attribute - Wan transformers use 'blocks' attribute
    transformer_blocks = getattr(transformer, "blocks", None)
    transformer_2_blocks = getattr(transformer_2, "blocks", None)

    if transformer_blocks is None or transformer_2_blocks is None:
        raise ValueError(
            "Dual transformers must have 'blocks' attribute for cache-dit. "
            f"transformer has blocks: {transformer_blocks is not None}, "
            f"transformer_2 has blocks: {transformer_2_blocks is not None}"
        )

    # Enable cache-dit using BlockAdapter for both transformers simultaneously
    # This is required for Wan2.2 and similar dual-transformer architectures
    if model_name == "wan2.2":
        # Use Pattern_2 for the Wan2.2 dual-transformer case. We check
        # `model_name` so that this is applied only to supported models;
        # different models may require a different ForwardPattern.
        cache_dit.enable_cache(
            BlockAdapter(
                transformer=[transformer, transformer_2],
                blocks=[transformer_blocks, transformer_2_blocks],
                forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
                params_modifiers=[primary_modifier, secondary_modifier],
                has_separate_cfg=True,
            ),
        )
    else:
        raise ValueError(
            f"Dual-transformer is not implemented for model {model_name} yet."
        )

    return transformer, transformer_2
